import h5py
import torch
from pytorch3d.io import load_ply
import os
import glob
from IPython import embed
from collections import defaultdict
# Root directories of the PartNet scan point clouds (PLY files).
# Complete shapes are read from gt_pc_path/<shape_name>.ply; partial scans are
# read from input_pc_path/<shape_name>/*.ply.
gt_pc_path = '/mnt/brain6/scratch/tiangel/partnet_scans-shape_mesh_scan_pc_rev'
input_pc_path = '/mnt/brain6/scratch/tiangel/partnet_scans-partial_mesh_scan_pc_rev'


def _shape_name(path):
    """Return *path*'s file name with its directory and final extension removed."""
    # splitext strips only the last extension, so dotted names such as
    # 'a.b.ply' map to 'a.b' (matching a partial-scan directory 'a.b/').
    return os.path.splitext(os.path.basename(path))[0]


def _load_points(path):
    """Load the vertices of the PLY file at *path*, adding a leading dim of size 1."""
    # load_ply returns (verts, faces); only the vertex positions are kept.
    return load_ply(path)[0].unsqueeze(0)


# glob's result order depends on the filesystem; sort so shape/scan order
# (and any indices derived from it when saving) is reproducible.
shape_paths = sorted(glob.glob(os.path.join(gt_pc_path, '*.ply')))
shape_names = [_shape_name(p) for p in shape_paths]
partial_paths = sorted(glob.glob(os.path.join(input_pc_path, '*', '*.ply')))

# shape name -> list of vertex tensors with a leading dim of 1
# (normally a single entry per complete shape).
shape_pcs = defaultdict(list)
for p, name in zip(shape_paths, shape_names):
    shape_pcs[name].append(_load_points(p))

# Set for O(1) membership checks in the loop below.
known_shapes = set(shape_names)
# shape name -> list of partial-scan vertex tensors (leading dim of 1 each).
partial_pcs = defaultdict(list)
for p in partial_paths:
    # Partial scans are grouped by their parent directory, named after the shape.
    name = os.path.basename(os.path.dirname(p))
    if name not in known_shapes:
        print('the partial pc doesnt have original pc', name)
        continue
    partial_pcs[name].append(_load_points(p))


# Shapes that have both a complete cloud and at least one partial scan
# (not used for filtering below; kept for inspection).
common_keys = set(shape_pcs.keys()).intersection(partial_pcs.keys())

save_path = '/home/tiangel/datasets/completion_partnet_data.h5'
# h5py cannot store a dict of lists of torch tensors directly: assigning one
# makes h5py build a NumPy object array, which raises TypeError. Instead,
# write one dataset per cloud at <group>/<shape_name>/<index>, holding the
# cloud's vertices with the leading size-1 dim squeezed away. Assigning to a
# slash-separated path creates the intermediate groups. The context manager
# guarantees the file is closed even if a write fails.
with h5py.File(save_path, 'w') as save_file:
    for group_name, pcs in (('shape_pcs', shape_pcs), ('partial_pcs', partial_pcs)):
        for name, clouds in pcs.items():
            for i, pc in enumerate(clouds):
                save_file[f'{group_name}/{name}/{i}'] = pc.squeeze(0).cpu().numpy()

# One tensor per shape holding all of its partial scans: each scan in
# partial_pcs carries a leading dim of 1 (from unsqueeze(0) at load time), so
# torch.cat stacks K scans along dim 0 and unsqueeze(0) adds one more leading
# dim. torch.cat requires every scan of a shape to have the same point count.
inputs = [torch.cat(scans).unsqueeze(0) for scans in partial_pcs.values()]

# Drop into an IPython shell to inspect the loaded data interactively.
embed()